{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a115ca1-3933-453c-b964-75acf70d67ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "## tags_embedding_analysis.ipynb\n",
    "## author: Francesco Garassino"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6d57866-1ce9-499e-85ae-9af6b1f489d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import pandas as pd\n",
    "import re\n",
    "from scipy.cluster.hierarchy import dendrogram, linkage\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from sklearn.cluster import AgglomerativeClustering\n",
    "from sklearn.decomposition import PCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1717e895-d859-44a2-98b4-70d08cf728a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_speaker_labels(text):\n",
    "    \"\"\"Return `text` with every speaker-label marker removed.\n",
    "\n",
    "    A marker is the word SPEAKER, one whitespace character, one or more\n",
    "    digits, one whitespace character and a clock value shaped like H:MM:SS\n",
    "    or HH:MM:SS. Whitespace directly in front of a marker is removed with\n",
    "    it; whitespace that follows the marker is kept.\n",
    "    \"\"\"\n",
    "    speaker_label = re.compile(r'\\s*SPEAKER\\s\\d+\\s\\d{1,2}:\\d{2}:\\d{2}')\n",
    "    cleaned = speaker_label.sub('', text)\n",
    "    return cleaned"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf72e20c-8ede-4798-9f5d-53dedd859b34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the pre-trained sentence-embedding model; later cells call model.encode() on it\n",
    "model = SentenceTransformer(model_name_or_path=\"all-mpnet-base-v2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5c18478-c9e5-44cd-a9b3-00c5abcd8c23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# output locations: plots_folder is the top-level plot directory,\n",
    "# intermediate_plots_folder is a sub-folder of it\n",
    "plots_folder = './outputs/plots/'\n",
    "intermediate_plots_folder = os.path.join(plots_folder, 'intermediate_plots')\n",
    "\n",
    "# create the nested folder (and any missing parents) only when it is absent\n",
    "if os.path.exists(intermediate_plots_folder):\n",
    "    print(f\"Folder already exists: {intermediate_plots_folder}\")\n",
    "else:\n",
    "    os.makedirs(intermediate_plots_folder)\n",
    "    print(f\"Created folder: {intermediate_plots_folder}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5282a978-c98d-4167-950d-eed2d5ef1147",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# read the tab-separated table of tagged sentences\n",
    "tags_df = pd.read_csv('inputs/2-AFFORD_interviews_all_sentences.txt', sep='\\t')\n",
    "\n",
    "# strip speaker-label markers from every entry of the 'content' column\n",
    "tags_df['content'] = tags_df['content'].map(remove_speaker_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee89d003-bf00-48c6-9987-673669b001f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# distinct values of the 'tag' column in sorted order, with the second one moved to the end\n",
    "# NOTE(review): index 1 depends on the sort order of the data -- confirm which tag is meant to go last\n",
    "tags = sorted(tags_df['tag'].unique())\n",
    "tags.append(tags.pop(1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06809706-1527-41af-962f-26999e873fc6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(\"Generating embeddings per-category and plotting corresponding dendrograms...\\n\")\n",
    "\n",
    "# per-tag results, index-aligned with `tags`; later cells read these lists\n",
    "embeddings_list = []\n",
    "linkage_matrices_list = []\n",
    "dendrogram_figures = []\n",
    "corpus_list = []\n",
    "\n",
    "for tag in tags:\n",
    "    print(tag)\n",
    "\n",
    "    # select the rows of this category; regex=False makes the tag a literal\n",
    "    # (case-insensitive) substring, so characters such as '.' or '(' inside a\n",
    "    # tag are not interpreted as regular-expression syntax\n",
    "    tag_mask = tags_df['tag'].str.contains(tag, case=False, na=False, regex=False)\n",
    "    filtered_tags_df = tags_df.loc[tag_mask]\n",
    "\n",
    "    # create embeddings by passing the corpus through the pre-trained model\n",
    "    corpus = filtered_tags_df['content'].tolist()\n",
    "    corpus_list.append(corpus)\n",
    "    embeddings = model.encode(corpus)\n",
    "    embeddings_list.append(embeddings)\n",
    "    print('\\tsize of embeddings array:', embeddings.shape)\n",
    "\n",
    "    # Ward linkage matrix computed from the sentence embeddings\n",
    "    linkage_matrix = linkage(embeddings, method='ward')\n",
    "    linkage_matrices_list.append(linkage_matrix)\n",
    "\n",
    "    # plot the dendrogram on this figure's own axes\n",
    "    fig, ax = plt.subplots(figsize=(6, 5))\n",
    "    dendrogram(linkage_matrix, color_threshold=0, no_labels=True, ax=ax)\n",
    "    ax.set_title(f'{tag} ({embeddings.shape[0]} sentences)')\n",
    "    ax.set_xlabel('Sentences')\n",
    "    ax.set_ylabel('Distance')\n",
    "\n",
    "    # keep the figure so the composite-figure cell can reuse it\n",
    "    dendrogram_figures.append(fig)\n",
    "\n",
    "    # also save the figure as standalone; the tag is sanitised outside the\n",
    "    # f-string because reusing the enclosing quote character inside an\n",
    "    # f-string is a syntax error before Python 3.12\n",
    "    safe_tag = tag.replace(' ', '-').replace('.', '')\n",
    "    fig.savefig(os.path.join(intermediate_plots_folder, f'{safe_tag}_dendrogram.png'))\n",
    "\n",
    "    # close the figure so it is not displayed or kept open by pyplot\n",
    "    plt.close(fig)\n",
    "    print('\\texported dendrogram\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cffbe85-a67d-4ad7-8b9d-7f5be8ab60b8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(\"Exporting composite figure of all dendrograms...\\n\")\n",
    "\n",
    "# grid layout: fixed number of columns, as many rows as needed\n",
    "num_dendrograms = len(dendrogram_figures)\n",
    "num_cols = 3\n",
    "num_rows = math.ceil(num_dendrograms / num_cols)\n",
    "\n",
    "# squeeze=False always yields a 2-D array of Axes, so flatten() works for\n",
    "# every grid shape (including a single row or a single cell)\n",
    "fig, axes = plt.subplots(num_rows, num_cols, figsize=(18, 5 * num_rows), squeeze=False)\n",
    "axes = axes.flatten()\n",
    "\n",
    "# paint each stored dendrogram figure into its own slot as an image\n",
    "for ax, dendro_fig in zip(axes, dendrogram_figures):\n",
    "    ax.imshow(dendro_fig.canvas.buffer_rgba())\n",
    "\n",
    "# hide the axis decorations everywhere, including slots left empty\n",
    "for ax in axes:\n",
    "    ax.axis('off')\n",
    "\n",
    "plt.tight_layout()\n",
    "fig.savefig(os.path.join(plots_folder, 'Per-category_tags_dendrograms.png'))\n",
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cb1276f-7cf8-416c-98ac-b9cb2c2bd396",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(\"Performing hierarchical clustering per-category and plotting corresponding dendrograms...\\n\")\n",
    "\n",
    "# number of clusters requested for each tag, index-aligned with `tags`\n",
    "n_clusters_list = [4, 3, 3, 4, 3, 3, 4, 3, 3, 4]\n",
    "\n",
    "# fail before any clustering or file output if a tag has no cluster count\n",
    "if len(n_clusters_list) < len(tags):\n",
    "    raise ValueError(\n",
    "        f'n_clusters_list has {len(n_clusters_list)} entries but there are {len(tags)} tags')\n",
    "\n",
    "# initialize lists to store each threshold distance, dendrogram figure, and cluster assignments\n",
    "threshold_distances_list = []\n",
    "dendrogram_figures = []\n",
    "cluster_assignments_list = []\n",
    "\n",
    "for i in range(len(tags)):\n",
    "    print(tags[i])\n",
    "\n",
    "    # per-tag inputs produced by the embedding cell\n",
    "    corpus = corpus_list[i]\n",
    "    embeddings = embeddings_list[i]\n",
    "    linkage_matrix = linkage_matrices_list[i]\n",
    "    nclust = n_clusters_list[i]\n",
    "\n",
    "    # tag rewritten for use in file names: spaces become '-', dots are dropped\n",
    "    safe_tag = tags[i].replace(' ', '-').replace('.', '')\n",
    "\n",
    "    # Perform hierarchical clustering\n",
    "    clustering_model = AgglomerativeClustering(\n",
    "        n_clusters=nclust, linkage='ward', distance_threshold=None)\n",
    "    clustering_model.fit(embeddings)\n",
    "\n",
    "    # retrieve assignments and distance at which the dendrogram was cut\n",
    "    cluster_assignment = clustering_model.labels_\n",
    "    cluster_assignments_list.append(cluster_assignment)\n",
    "    # distance column of the (n_clusters - 1)-th row counted from the end of the linkage matrix\n",
    "    # NOTE(review): with n_clusters == 1 the index becomes 0 (the first row) -- confirm\n",
    "    # that a single-cluster cut is never requested\n",
    "    threshold_distance = linkage_matrix[-(clustering_model.n_clusters - 1), 2]\n",
    "    threshold_distances_list.append(threshold_distance)\n",
    "\n",
    "    # re-plot dendrogram, coloured by the cut, with the cut drawn as a dotted line\n",
    "    fig, ax = plt.subplots(figsize=(6, 5))\n",
    "    dendrogram(linkage_matrix, color_threshold=threshold_distance, no_labels=True, ax=ax)\n",
    "    ax.axhline(y=threshold_distance, color='black', linestyle=':')\n",
    "    ax.set_title(f'{tags[i]} ({embeddings.shape[0]} sentences)')\n",
    "    ax.set_xlabel('Sentences')\n",
    "    ax.set_ylabel('Distance')\n",
    "\n",
    "    # save the figure in the list\n",
    "    dendrogram_figures.append(fig)\n",
    "\n",
    "    # also save to file\n",
    "    filename = os.path.join(intermediate_plots_folder, f'{safe_tag}_dendrogram_clustered.png')\n",
    "    fig.savefig(filename)\n",
    "\n",
    "    # close the figure to prevent display during the loop\n",
    "    plt.close(fig)\n",
    "\n",
    "    print('\\texported dendrogram with clustering')\n",
    "\n",
    "    # group the sentences by cluster id, clusters in order of first appearance\n",
    "    clustered_sentences = {}\n",
    "    for sentence, cluster_id in zip(corpus, cluster_assignment):\n",
    "        clustered_sentences.setdefault(cluster_id, []).append(sentence)\n",
    "\n",
    "    # one (cluster, content) row per sentence, grouped by cluster\n",
    "    clusters_df = pd.DataFrame(\n",
    "        [(cluster_id, sentence)\n",
    "         for cluster_id, sentences in clustered_sentences.items()\n",
    "         for sentence in sentences],\n",
    "        columns=['cluster', 'content'])\n",
    "\n",
    "    # attach the cluster ids to the rows of tags_df by sentence text, then keep\n",
    "    # only rows carrying exactly this tag, one row per 'id'\n",
    "    clustered_df = pd.merge(tags_df, clusters_df, on='content', how='right')\n",
    "    clustered_df = clustered_df[(clustered_df['tag'] == tags[i])].drop_duplicates(subset='id')\n",
    "\n",
    "    clustered_df.to_csv(f'outputs/{safe_tag}_tags_clustered.csv', sep=',')\n",
    "\n",
    "    print('\\texported CSV table with clustered tags\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10c13a42-c0af-43b3-8f67-fd275e74b0be",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Exporting composite figure of all dendrograms with clustering...\\n\")\n",
    "\n",
    "# derive the row count from the figures produced by the clustering cell rather\n",
    "# than reusing a value computed earlier; num_cols comes from the earlier composite-figure cell\n",
    "num_rows = math.ceil(len(dendrogram_figures) / num_cols)\n",
    "\n",
    "# squeeze=False always yields a 2-D array of Axes, so flatten() works for\n",
    "# every grid shape (including a single row or a single cell)\n",
    "fig, axes = plt.subplots(num_rows, num_cols, figsize=(18, 5 * num_rows), squeeze=False)\n",
    "axes = axes.flatten()\n",
    "\n",
    "# paint each stored dendrogram figure into its own slot as an image\n",
    "for ax, dendro_fig in zip(axes, dendrogram_figures):\n",
    "    ax.imshow(dendro_fig.canvas.buffer_rgba())\n",
    "\n",
    "# hide the axis decorations everywhere, including slots left empty\n",
    "for ax in axes:\n",
    "    ax.axis('off')\n",
    "\n",
    "plt.tight_layout()\n",
    "fig.savefig(os.path.join(plots_folder, 'Per-category_tags_clustered_dendrograms.png'))\n",
    "plt.close(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
